# All helper functions 

# Model parameter container. `@with_kw` (presumably the Parameters.jl macro) is
# expected to generate a keyword constructor so single fields can be overridden
# while the others keep the defaults written below.
# The defaults read element 1 of columns of `fixed_cost_param_df` and
# `baseline_parameters_df`, plus the value `eta_tilde`. None of these names is
# defined in this struct (presumably file-level globals -- confirm they are
# loaded before the struct is constructed).
# NOTE: field order is significant, because the `Base.getindex` method defined
# for this type looks fields up by position.
@with_kw struct define_param
    # fixed-cost parameters
    level_fcost::Float64 = fixed_cost_param_df.level_fcost[1]
    scale_fcost::Float64 = fixed_cost_param_df.scale_fcost[1]
    # `connect_*` parameters (theta / cost_level / cost_elasticity / fixed_cost
    # are the ones read by the compute_subsidy* functions)
    connect_theta::Float64 = baseline_parameters_df.theta[1]
    connect_cost_level::Float64 = baseline_parameters_df.cost_level[1]
    connect_cost_elasticity::Float64 = baseline_parameters_df.cost_elasticity[1]
    connect_fixed_cost::Float64 = 0.0
    connect_rho::Float64 = baseline_parameters_df.rho[1]
    connect_proba_c::Float64 = baseline_parameters_df.proba_c[1]
    # mean (alpha_eps + beta_eps * log z) and variance of the epsilon draws
    alpha_eps::Float64 = baseline_parameters_df.alpha_eps[1]
    beta_eps::Float64 = baseline_parameters_df.beta_eps[1]
    variance_eps_z::Float64 = baseline_parameters_df.variance_eps_z[1]
    # computed from `eta_tilde` rather than read from a data frame
    sigma::Float64 = 1/(1-eta_tilde)
    # same default source as `connect_proba_c` above
    proba_c::Float64 = baseline_parameters_df.proba_c[1]
end

# want to be able to index the structure
# Positional access: `p[i]` returns the value of the i-th declared field of `p`.
function Base.getindex(x::define_param, i)
    field_names = fieldnames(typeof(x))
    return getproperty(x, field_names[i])
end

### Get expected fixed costs ### 

"""
    broadcast_inc_gamma(x)

Return `sf_gamma_inc(0.0, x)` after restricting `x` to the interval `[1e-250, 150]`.

Values above 150 are replaced by 150 and values below 1e-250 by 1e-250 before
the call. `sf_gamma_inc` is not defined in this file (presumably the GSL
incomplete gamma function -- confirm); the bounds presumably keep the argument
inside the range where that routine evaluates without error.

`clamp` keeps the element type of the bounded value stable (always a float for
float input) instead of mixing the integer literal `150` with floats.
"""
function broadcast_inc_gamma(x)
    x_bounded = clamp.(x, 1e-250, 150)
    return sf_gamma_inc(0.0, x_bounded)
end

"""
    compute_exp_fcost(; exit_proba, exp_VF, param)

Return, elementwise,

    beta * exp_VF * exit_proba - param.scale_fcost * broadcast_inc_gamma(-log(exit_proba))

`beta` is read from the enclosing (global) scope.
"""
function compute_exp_fcost(; exit_proba, exp_VF, param)
    discounted_value = beta .* exp_VF .* exit_proba
    gamma_term = broadcast_inc_gamma.(-log.(exit_proba))
    return discounted_value .- param.scale_fcost .* gamma_term
end

### Get expected profits ### 

"""
    compute_profits_NC(; w, Y, param, tax = nothing, z_grid = z_grid_baseline)

Compute revenue and after-tax profits on `z_grid` for the NC (no-subsidy) case.

# Keywords
- `w`: wage entering the labor term of `x_star`.
- `Y`: aggregate entering through `x_bar = Y^(1/param.sigma)`.
- `param`: parameter object; only `param.sigma` is read.
- `tax`: tax rate used in `x_star`; when `nothing`, the global `vat_tax` is used.
- `z_grid`: grid that is transformed into `z_star`.

# Returns
A named tuple with `revenue`, `profits` and the aliases `expected_revenue`,
`expected_profits`, `output_aggr` and `expected_revenue_output`. The last two
are the same array as `revenue`.

Reads the globals `alpha_gross`, `beta_gross`, `gamma_gross`, `rental_rate`,
`eta_tilde`, `profit_tax` and `vat_tax`.
"""
function compute_profits_NC(; w, Y, param, tax=nothing, z_grid = z_grid_baseline)

    # fall back to `vat_tax` when no tax rate is supplied
    if tax === nothing
        tax = copy(vat_tax)
    end

    # Step 1: compute z_star & x_star (both are functions of w & Y)
    x_bar = Y^(1/param.sigma)
    z_star = (z_grid.^((param.sigma - 1)/param.sigma)) .* x_bar
    alpha_term = ((1-tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross
    beta_term = ((1-tax)*beta_gross/w)^beta_gross
    gamma_term = gamma_gross^gamma_gross
    x_star = alpha_term*beta_term*gamma_term

    # Step 2: compute revenue and profits
    revenue = (z_star .* x_star).^(1/(1-eta_tilde))
    profits = (1-profit_tax) .* (1-eta_tilde) .* revenue

    # without a subsidy there is nothing to integrate over, so the "expected"
    # objects and the output measures are the realized ones
    return (
        revenue = revenue, expected_revenue = revenue,
        profits = profits, expected_profits = profits,
        output_aggr = revenue, expected_revenue_output = revenue)
end

"""
    compute_subsidy_baseline(; z_star_grid, epsilon_grid, param)

Solve the rent-seeking choice `m_R` element by element, with the wage fixed at 1
and the tax rate fixed at the global `vat_tax`.

For each index `i` the function

1. builds `z_tilde = (z_star * x_star)^(1/(1-eta_tilde))`;
2. sets `m_R = 0` when `epsilon_grid[i] < cost_level`, otherwise finds the root
   of the first-order condition by bisection on `(1e-32, upper_bound[i])`;
3. evaluates `subsidy = epsilon*m_R^theta - cost_level*m_R^cost_elast`,
   revenue, and profits with (`C`) and without (`NC`) the subsidy;
4. sets `choose_C` where `C` profits net of `(1-profit_tax)*fixed_cost` exceed
   `NC` profits.

# Keywords
- `z_star_grid`, `epsilon_grid`: indexed with the same index, so they must have
  equal length.
- `param`: supplies `connect_theta`, `connect_cost_level`,
  `connect_cost_elasticity` and `connect_fixed_cost`.

# Returns
A named tuple. `optimal_subsidy`, `optimal_m_R`, `optimal_revenue`,
`optimal_profits`, `TFPQ_model` and `revenue_output_C` reflect the `choose_C`
decision (subsidy and `m_R` are zero where `choose_C` is false).
`optimal_profits_C` is the unmasked `C` profit before the fixed cost.

`Roots.find_zero` throws if the first-order condition has no sign change on the
bracket.
"""
function compute_subsidy_baseline(; z_star_grid, epsilon_grid, param)

    #### Step 1: Preparation ####

    theta = param.connect_theta
    cost_level = param.connect_cost_level
    cost_elast = param.connect_cost_elasticity
    fixed_cost = param.connect_fixed_cost

    # baseline x_star (enforcing w = 1.0)
    alpha_term = ((1-vat_tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross
    beta_term = ((1-vat_tax)*beta_gross)^beta_gross
    gamma_term = gamma_gross^gamma_gross
    x_star = alpha_term*beta_term*gamma_term

    z_tilde_grid = (z_star_grid .* x_star).^(1/(1-eta_tilde))

    # first-order condition for m_R: marginal after-subsidy revenue minus 1
    function rent_seeking_FOC(m_R, z_tilde_value, epsilon)
        return z_tilde_value .* ( (1.0 .+ (epsilon .*((m_R).^(theta)) .- cost_level .* (m_R.^cost_elast))).^elasticity_ratio ) .* ( (theta .* epsilon .* ((m_R).^(theta - 1.0)) .- cost_level .* cost_elast .* (m_R.^(cost_elast-1.0))) ) .- 1
    end

    #### Step 2: Solve for m_R* using FOCs ####

    # bracket: something close to zero up to the m_R that maximizes the subsidy
    # (never want to spend beyond that point)
    lower_bound = 1e-32
    upper_bound = ((theta .* epsilon_grid) ./ (cost_level .* cost_elast)).^(1.0 ./ (cost_elast - theta))

    m_R_star = zeros(length(epsilon_grid))
    for row in eachindex(m_R_star)
        # below the cost level the entry stays at m_R = 0
        epsilon_grid[row] < cost_level && continue
        m_R_star[row] = Roots.find_zero(
            m_R -> rent_seeking_FOC(m_R, z_tilde_grid[row], epsilon_grid[row]),
            (lower_bound, upper_bound[row]), Bisection())
    end

    #### Step 3: Solve for remaining objects needed ####

    subsidy_C = epsilon_grid .* (m_R_star.^theta) .- cost_level .* (m_R_star.^cost_elast)
    revenue_C = z_tilde_grid .* ( (1.0 .+ subsidy_C).^(elasticity_ratio .+ 1.0) )
    optimal_profits_C = (1-profit_tax) .* ((1.0 .- eta_tilde) .* revenue_C .- m_R_star)
    optimal_profits_NC = (1-profit_tax) .* (1.0 .- eta_tilde) .* z_tilde_grid

    # take the subsidy only if it still pays after the (taxed) fixed cost
    net_profits_C = optimal_profits_C .- ((1-profit_tax) .* fixed_cost)
    choose_C = net_profits_C .> optimal_profits_NC

    optimal_subsidy = ifelse.(choose_C, subsidy_C, 0.0)
    optimal_m_R = ifelse.(choose_C, m_R_star, 0.0)
    optimal_revenue = ifelse.(choose_C, revenue_C, z_tilde_grid)
    optimal_profits = ifelse.(choose_C, net_profits_C, optimal_profits_NC)

    #### Return final results ####
    return (
        z_star = z_star_grid, epsilon = epsilon_grid, z_tilde = z_tilde_grid,
        choose_C = choose_C, optimal_subsidy = optimal_subsidy, optimal_m_R = optimal_m_R,
        optimal_revenue = optimal_revenue, optimal_profits_C = optimal_profits_C,
        optimal_profits_NC = optimal_profits_NC, optimal_profits = optimal_profits,
        TFPQ_model = (1.0 .+ optimal_subsidy) .* z_star_grid,
        revenue_output_C = z_tilde_grid .* ((1.0 .+ optimal_subsidy).^elasticity_ratio),
        revenue_output_NC = z_tilde_grid)
end

## compute subsidy  
"""
    compute_subsidy(; z_star_grid, epsilon_grid, z_tilde_grid, param, tax = nothing)

Solve the rent-seeking choice `m_R` element by element for a caller-supplied
`z_tilde_grid` and return the results as a `DataFrame`.

For each index `i`, `m_R` is zero when `epsilon_grid[i] < cost_level`; otherwise
it is the bisection root of the first-order condition on
`(1e-32, upper_bound[i])`. The subsidy, revenue and profits with (`C`) and
without (`NC`) the subsidy follow from `m_R`, and `choose_C` marks the entries
where `C` profits net of `(1-profit_tax)*fixed_cost` exceed `NC` profits.

# Keywords
- `z_star_grid`, `epsilon_grid`, `z_tilde_grid`: indexed with the same index,
  so they must have equal length.
- `param`: supplies `connect_theta`, `connect_cost_level`,
  `connect_cost_elasticity` and `connect_fixed_cost`.
- `tax`: accepted so existing calls keep working; no computation in this
  function reads it (taxes enter only through the supplied `z_tilde_grid`).

# Returns
A `DataFrame` with columns `z_star`, `epsilon`, `z_tilde`, `choose_C`,
`optimal_subsidy`, `optimal_m_R`, `optimal_revenue`, `optimal_profits_C`,
`optimal_profits_NC`, `optimal_profits`, `TFPQ_model`, `revenue_output_C` and
`revenue_output_NC`. Subsidy and `m_R` are zero where `choose_C` is false;
`optimal_profits_C` is the unmasked `C` profit before the fixed cost.
"""
function compute_subsidy(; z_star_grid, epsilon_grid, z_tilde_grid, param, tax = nothing)

    #### Step 1: Preparation ####

    theta = param.connect_theta
    cost_level = param.connect_cost_level
    cost_elast = param.connect_cost_elasticity
    fixed_cost = param.connect_fixed_cost

    # first-order condition for m_R: marginal after-subsidy revenue minus 1
    function rent_seeking_FOC(m_R, z_tilde_value, epsilon)
        return z_tilde_value .* ( (1.0 .+ (epsilon .*((m_R).^(theta)) .- cost_level .* (m_R.^cost_elast))).^elasticity_ratio ) .* ( (theta .* epsilon .* ((m_R).^(theta - 1.0)) .- cost_level .* cost_elast .* (m_R.^(cost_elast-1.0))) ) .- 1
    end

    #### Step 2: Solve for m_R* using FOCs ####

    # bracket: something close to zero up to the m_R that maximizes the subsidy
    # (never want to spend beyond that point)
    lower_bound = 1e-32
    upper_bound = ((theta .* epsilon_grid) ./ (cost_level .* cost_elast)).^(1.0 ./ (cost_elast - theta))

    m_R_star = zeros(length(epsilon_grid))
    for row in eachindex(m_R_star)
        # below the cost level the entry stays at m_R = 0
        epsilon_grid[row] < cost_level && continue
        m_R_star[row] = Roots.find_zero(
            m_R -> rent_seeking_FOC(m_R, z_tilde_grid[row], epsilon_grid[row]),
            (lower_bound, upper_bound[row]), Bisection())
    end

    #### Step 3: Solve for remaining objects needed ####

    subsidy_C = epsilon_grid .* (m_R_star.^theta) .- cost_level .* (m_R_star.^cost_elast)
    revenue_C = z_tilde_grid .* ( (1.0 .+ subsidy_C).^(elasticity_ratio .+ 1.0) )
    optimal_profits_C = (1-profit_tax) .* ((1.0 .- eta_tilde) .* revenue_C .- m_R_star)
    optimal_profits_NC = (1-profit_tax) .* (1.0 .- eta_tilde) .* z_tilde_grid

    # take the subsidy only if it still pays after the (taxed) fixed cost
    net_profits_C = optimal_profits_C .- ((1-profit_tax) .* fixed_cost)
    choose_C = net_profits_C .> optimal_profits_NC

    optimal_subsidy = ifelse.(choose_C, subsidy_C, 0.0)
    optimal_m_R = ifelse.(choose_C, m_R_star, 0.0)
    optimal_revenue = ifelse.(choose_C, revenue_C, z_tilde_grid)
    optimal_profits = ifelse.(choose_C, net_profits_C, optimal_profits_NC)

    #### Return final results ####
    return DataFrame(
        z_star = z_star_grid,
        epsilon = epsilon_grid,
        z_tilde = z_tilde_grid,
        choose_C = choose_C,
        optimal_subsidy = optimal_subsidy,
        optimal_m_R = optimal_m_R,
        optimal_revenue = optimal_revenue,
        optimal_profits_C = optimal_profits_C,
        optimal_profits_NC = optimal_profits_NC,
        optimal_profits = optimal_profits,
        TFPQ_model = (1.0 .+ optimal_subsidy) .* z_star_grid,
        revenue_output_C = z_tilde_grid .* ((1.0 .+ optimal_subsidy).^elasticity_ratio),
        revenue_output_NC = z_tilde_grid
    )
end

# compute profits 
"""
    compute_profits_grid_baseline(; z_star_grid, n_eps, param)

For every entry of `z_star_grid`, draw `n_eps` values of epsilon, solve
`compute_subsidy_baseline` for those draws, and aggregate the results.

Epsilon is drawn from
`Normal(param.alpha_eps + param.beta_eps*log(z_star), sqrt(param.variance_eps_z))`.
The global RNG is re-seeded with `12345` before each row's draw. Each draw is
weighted by its density under that distribution, normalized to sum to one.

# Returns
A named tuple with
- per-row weighted sums (`expected_*_C`, `expected_*_NC`), each of length
  `length(z_star_grid)`;
- mixtures over `param.proba_c` (`expected_profits`, `expected_revenue_output`,
  `expected_revenue`, `expected_subsidies`);
- stacked per-draw vectors (`*_variation`) of length `n_rows*n_eps`, where row
  `r` occupies positions `(r-1)*n_eps+1 : r*n_eps`.

`expected_subsidies` is the last entry of the tuple; it mirrors the field of the
same name returned by `compute_profits_grid_cf`.
"""
function compute_profits_grid_baseline(; z_star_grid, n_eps, param)

    # baseline x_star (enforcing w = 1.0)
    alpha_term = ((1-vat_tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross
    beta_term = ((1-vat_tax)*beta_gross)^beta_gross
    gamma_term = gamma_gross^gamma_gross
    x_star = alpha_term*beta_term*gamma_term

    z_tilde_grid = (z_star_grid .* x_star).^(1/(1-eta_tilde))

    n_rows = length(z_tilde_grid)
    n_total = n_rows*n_eps

    # per-row aggregates
    expected_profits_C = zeros(n_rows)
    expected_profits_NC = zeros(n_rows)
    expected_m_R_C = zeros(n_rows)
    expected_subsidies_C = zeros(n_rows)
    expected_revenue_output_C = zeros(n_rows)
    expected_revenue_output_NC = zeros(n_rows)
    expected_subsidy_rate_C = zeros(n_rows)
    expected_revenue_C = zeros(n_rows)
    expected_revenue_NC = copy(z_tilde_grid)

    # stacked per-draw results
    subsidy_variation = zeros(n_total)
    epsilon_variation = zeros(n_total)
    epsilon_proba_variation = zeros(n_total)
    rent_seeking_share_variation = zeros(n_total)
    z_star_variation = zeros(Float64, n_total)
    profits_C_variation = zeros(Float64, n_total)
    revenue_C_variation = zeros(Float64, n_total)
    m_R_variation = zeros(Float64, n_total)

    for row in 1:n_rows

        ### Substep 1: Draw epsilon conditional on z_star & the probability
        Random.seed!(12345)
        normal_distrib = Normal(param.alpha_eps + param.beta_eps*log(z_star_grid[row]), sqrt(param.variance_eps_z))
        epsilon = rand(normal_distrib, n_eps)
        proba_epsilon = pdf.(normal_distrib, epsilon)

        # normalize probability to make sure it sums to 1
        proba_epsilon = proba_epsilon ./ sum(proba_epsilon)

        ### Substep 2: Get optimal profits
        results = compute_subsidy_baseline(
            z_star_grid = ones(n_eps) .* z_star_grid[row],
            epsilon_grid = epsilon,
            param = param)

        # positions of this row's draws inside the stacked vectors
        block = ((row-1)*n_eps + 1):(row*n_eps)

        z_star_variation[block] .= z_star_grid[row]
        profits_C_variation[block] = results.optimal_profits_C
        revenue_C_variation[block] = results.optimal_revenue
        subsidy_variation[block] = results.optimal_subsidy
        epsilon_variation[block] = epsilon
        epsilon_proba_variation[block] = proba_epsilon
        rent_seeking_share_variation[block] = results.optimal_m_R ./ (results.optimal_m_R .+ (gamma_gross .* results.optimal_revenue))
        m_R_variation[block] = results.optimal_m_R

        # weighted means across all epsilon draws
        expected_profits_C[row] = sum(results.optimal_profits_C .* proba_epsilon)
        expected_profits_NC[row] = sum(results.optimal_profits_NC .* proba_epsilon)
        expected_m_R_C[row] = sum(results.optimal_m_R .* proba_epsilon)
        expected_subsidies_C[row] = sum(results.optimal_revenue .* (results.optimal_subsidy ./ (1.0 .+ results.optimal_subsidy)) .* proba_epsilon)
        expected_revenue_output_C[row] = sum(results.revenue_output_C .* proba_epsilon)
        expected_revenue_output_NC[row] = sum(results.revenue_output_NC .* proba_epsilon)
        expected_subsidy_rate_C[row] = sum(results.optimal_subsidy .* proba_epsilon)
        expected_revenue_C[row] = sum(results.optimal_revenue .* proba_epsilon)
    end

    # mix the C and NC objects with weight param.proba_c on C
    expected_profits = param.proba_c .* expected_profits_C .+ (1.0 - param.proba_c) .* expected_profits_NC
    expected_subsidies = param.proba_c .* expected_subsidies_C
    expected_revenue_output = param.proba_c .* expected_revenue_output_C .+ (1.0 - param.proba_c) .* expected_revenue_output_NC
    expected_revenue = param.proba_c .* expected_revenue_C .+ (1.0 - param.proba_c) .* expected_revenue_NC

    return (
        expected_profits_C = expected_profits_C,
        expected_profits_NC = expected_profits_NC,
        expected_profits = expected_profits,
        expected_m_R_C = expected_m_R_C,
        expected_subsidies_C = expected_subsidies_C,
        expected_revenue_output_C = expected_revenue_output_C,
        expected_revenue_output_NC = expected_revenue_output_NC,
        expected_revenue_output = expected_revenue_output,
        expected_subsidy_rate_C = expected_subsidy_rate_C,
        expected_revenue_C = expected_revenue_C,
        expected_revenue_NC = expected_revenue_NC,
        expected_revenue = expected_revenue,
        subsidy_variation = subsidy_variation,
        epsilon_variation = epsilon_variation,
        epsilon_proba_variation = epsilon_proba_variation,
        rent_seeking_share_variation = rent_seeking_share_variation,
        z_star_variation = z_star_variation,
        profits_C_variation = profits_C_variation,
        revenue_C_variation = revenue_C_variation,
        m_R_variation = m_R_variation,
        expected_subsidies = expected_subsidies
        )
end

# compute policies for visualization (keeps epsilons fixed)
"""
    compute_policies_baseline(; z_star_grid, eps, param)

Evaluate `compute_subsidy_baseline` for every combination of `z_star_grid[row]`
and the fixed vector `eps`, and collect the policies in matrices.

# Keywords
- `z_star_grid`: one matrix row per entry.
- `eps`: epsilon values shared by all rows, one matrix column per entry. (The
  keyword name shadows `Base.eps` inside this function.)
- `param`: forwarded to `compute_subsidy_baseline`.

# Returns
A named tuple of `length(z_star_grid)` by `length(eps)` matrices:
`subsidy_variation`, `epsilon_variation`, `rent_seeking_share_variation`,
`z_star_variation`, `profits_C_variation`, `revenue_C_variation` and
`m_R_variation`.
"""
function compute_policies_baseline(; z_star_grid, eps, param)

    n_rows = length(z_star_grid)
    n_eps = length(eps)

    subsidy_variation = zeros(n_rows, n_eps)
    epsilon_variation = zeros(n_rows, n_eps)
    rent_seeking_share_variation = zeros(n_rows, n_eps)
    z_star_variation = zeros(Float64, n_rows, n_eps)
    profits_C_variation = zeros(Float64, n_rows, n_eps)
    revenue_C_variation = zeros(Float64, n_rows, n_eps)
    m_R_variation = zeros(Float64, n_rows, n_eps)

    for row in 1:n_rows

        # solve the subsidy problem for this z_star against every epsilon
        results = compute_subsidy_baseline(
            z_star_grid = ones(n_eps) .* z_star_grid[row],
            epsilon_grid = eps,
            param = param)

        z_star_variation[row, :] .= z_star_grid[row]
        epsilon_variation[row, :] = eps
        profits_C_variation[row, :] = results.optimal_profits_C
        revenue_C_variation[row, :] = results.optimal_revenue
        subsidy_variation[row, :] = results.optimal_subsidy
        m_R_variation[row, :] = results.optimal_m_R
        rent_seeking_share_variation[row, :] = results.optimal_m_R ./ (results.optimal_m_R .+ (gamma_gross .* results.optimal_revenue))
    end

    return (
        subsidy_variation = subsidy_variation,
        epsilon_variation = epsilon_variation,
        rent_seeking_share_variation = rent_seeking_share_variation,
        z_star_variation = z_star_variation,
        profits_C_variation = profits_C_variation,
        revenue_C_variation = revenue_C_variation,
        m_R_variation = m_R_variation
        )
end

# compute profits for any counterfactual 
"""
    compute_profits_grid_cf(; w, Y, tax = nothing, n_eps = 500, param, z_grid = z_grid_baseline)

Expected profits, revenue and subsidies on `z_grid` for a counterfactual with
wage `w`, aggregate `Y` and tax rate `tax`.

`z_star` is built from `z_grid` and `Y`, and `x_star` from `w` and `tax`. Both
feed `z_tilde_grid`, and each row is passed to `compute_subsidy` together with
`n_eps` epsilon draws. Epsilon is drawn from
`Normal(param.alpha_eps + param.beta_eps*log(z_grid[row]), sqrt(param.variance_eps_z))`.
The global RNG is re-seeded with `12345` before each row's draw, and each draw
is weighted by its normalized density.

# Keywords
- `tax`: when `nothing`, the global `vat_tax` is used.
- `param`: supplies `sigma`, `alpha_eps`, `beta_eps`, `variance_eps_z` and
  `proba_c`, and is forwarded to `compute_subsidy`.

# Returns
A named tuple of vectors of length `length(z_grid)`: the per-row weighted sums
(`expected_*_C`, `expected_*_NC`) and their mixtures over `param.proba_c`
(`expected_profits`, `expected_subsidies`, `expected_revenue_output`,
`expected_revenue`).
"""
function compute_profits_grid_cf(; w, Y, tax = nothing, n_eps = 500, param, z_grid = z_grid_baseline)

    sigma = param.sigma

    # fall back to `vat_tax` when no tax rate is supplied
    if tax === nothing
        tax = copy(vat_tax)
    end

    # Step 1: compute z_star & x_star (both are functions of w & Y)
    x_bar = Y^(1/sigma)
    z_star = (z_grid.^((sigma - 1)/sigma)) .* x_bar
    alpha_term = ((1-tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross
    beta_term = ((1-tax)*beta_gross/w)^beta_gross
    gamma_term = gamma_gross^gamma_gross
    x_star = alpha_term*beta_term*gamma_term

    # built from the z_star computed above (the name `z_star_grid` is not
    # defined in this function)
    z_tilde_grid = (z_star .* x_star).^(1/(1-eta_tilde))

    n_rows = length(z_tilde_grid)

    expected_profits_C = zeros(n_rows)
    expected_profits_NC = zeros(n_rows)
    expected_m_R_C = zeros(n_rows)
    expected_subsidies_C = zeros(n_rows)
    expected_revenue_output_C = zeros(n_rows)
    expected_revenue_output_NC = zeros(n_rows)
    expected_subsidy_rate_C = zeros(n_rows)
    expected_revenue_C = zeros(n_rows)
    expected_revenue_NC = copy(z_tilde_grid)

    for row in 1:n_rows

        ### Substep 1: Draw epsilon conditional on z & the probability
        Random.seed!(12345)
        normal_distrib = Normal(param.alpha_eps + param.beta_eps*log(z_grid[row]), sqrt(param.variance_eps_z))
        epsilon = rand(normal_distrib, n_eps)
        proba_epsilon = pdf.(normal_distrib, epsilon)

        # normalize probability to make sure it sums to 1
        proba_epsilon = proba_epsilon ./ sum(proba_epsilon)

        ### Substep 2: Get optimal profits
        results = compute_subsidy(
            z_star_grid = ones(n_eps) .* z_star[row],
            epsilon_grid = epsilon,
            z_tilde_grid = ones(n_eps) .* z_tilde_grid[row],
            param = param,
            tax = tax)

        # weighted means across all epsilon draws
        expected_profits_C[row] = sum(results.optimal_profits_C .* proba_epsilon)
        expected_profits_NC[row] = sum(results.optimal_profits_NC .* proba_epsilon)
        expected_m_R_C[row] = sum(results.optimal_m_R .* proba_epsilon)
        expected_subsidies_C[row] = sum(results.optimal_revenue .* (results.optimal_subsidy ./ (1.0 .+ results.optimal_subsidy)) .* proba_epsilon)
        expected_revenue_output_C[row] = sum(results.revenue_output_C .* proba_epsilon)
        expected_revenue_output_NC[row] = sum(results.revenue_output_NC .* proba_epsilon)
        expected_subsidy_rate_C[row] = sum(results.optimal_subsidy .* proba_epsilon)
        expected_revenue_C[row] = sum(results.optimal_revenue .* proba_epsilon)
    end

    # mix the C and NC objects with weight param.proba_c on C
    expected_profits = param.proba_c .* expected_profits_C .+ (1.0 - param.proba_c) .* expected_profits_NC
    expected_subsidies = param.proba_c .* expected_subsidies_C
    expected_revenue_output = param.proba_c .* expected_revenue_output_C .+ (1.0 - param.proba_c) .* expected_revenue_output_NC
    expected_revenue = param.proba_c .* expected_revenue_C .+ (1.0 - param.proba_c) .* expected_revenue_NC

    return (
        expected_profits_C = expected_profits_C,
        expected_profits_NC = expected_profits_NC,
        expected_profits = expected_profits,
        expected_m_R_C = expected_m_R_C,
        expected_subsidies = expected_subsidies,
        expected_subsidies_C = expected_subsidies_C,
        expected_revenue_output_C = expected_revenue_output_C,
        expected_revenue_output_NC = expected_revenue_output_NC,
        expected_revenue_output = expected_revenue_output,
        expected_subsidy_rate_C = expected_subsidy_rate_C,
        expected_revenue_C = expected_revenue_C,
        expected_revenue_NC = expected_revenue_NC,
        expected_revenue = expected_revenue
        )
end

# compute profits for the fixed-subsidy counterfactual
"""
    compute_profits_grid_cf_fix_subsidy(; w, Y, subsidy, tax = nothing, param, z_grid = z_grid)

Expected profits, revenue and subsidy payments on `z_grid` when the subsidy rate
is imposed as `subsidy` and rent-seeking spending is set to zero.

# Keywords
- `w`, `Y`: wage and aggregate entering `x_star` and `z_star`.
- `subsidy`: subsidy rate applied in the `C` case; combined with the grid by
  broadcasting.
- `tax`: when `nothing`, the global `vat_tax` is used.
- `param`: supplies `sigma` and `proba_c`.
- `z_grid`: the default is looked up from an outer-scope binding named `z_grid`.
  NOTE(review): the sibling functions default to `z_grid_baseline`; confirm
  which one is intended.

# Returns
A named tuple with `expected_profits`, `expected_subsidies`, `expected_revenue`
and `expected_revenue_output` (mixtures of the `C` and `NC` objects with weight
`param.proba_c` on `C`), plus `subsidy_amount` (the unweighted subsidy payment
in the `C` case).
"""
function compute_profits_grid_cf_fix_subsidy(; w, Y, subsidy, tax = nothing, param, z_grid = z_grid)

    sigma = param.sigma

    # fall back to `vat_tax` when no tax rate is supplied
    if tax === nothing
        tax = copy(vat_tax)
    end

    # Step 1: compute z_star & x_star (both are functions of w & Y)
    x_bar = Y^(1/sigma)
    z_star = (z_grid.^((sigma - 1)/sigma)) .* x_bar
    alpha_term = ((1-tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross
    beta_term = ((1-tax)*beta_gross/w)^beta_gross
    gamma_term = gamma_gross^gamma_gross
    x_star = alpha_term*beta_term*gamma_term

    # built from the z_star computed above (the name `z_star_grid` is not
    # defined in this function)
    z_tilde_grid = (z_star .* x_star).^(1/(1-eta_tilde))

    # Step 2: objects with (C) and without (NC) the imposed subsidy
    optimal_m_R = 0.0 # assume zero costs
    optimal_revenue_C = z_tilde_grid .* ( (1.0 .+ subsidy).^(elasticity_ratio .+ 1.0) )
    optimal_revenue_NC = z_tilde_grid
    optimal_profits_C = (1-profit_tax) .* ((1.0 .- eta_tilde) .* optimal_revenue_C .- optimal_m_R)
    optimal_profits_NC = (1-profit_tax) .* (1.0 .- eta_tilde) .* z_tilde_grid
    subsidy_amount = optimal_revenue_C .* (subsidy ./ (1.0 .+ subsidy))
    revenue_output_C = z_tilde_grid .* ((1.0 .+ subsidy).^elasticity_ratio)
    revenue_output_NC = z_tilde_grid

    # Step 3: construct weighted/expected versions
    expected_profits = param.proba_c .* optimal_profits_C .+ (1.0 - param.proba_c) .* optimal_profits_NC
    expected_subsidies = param.proba_c .* subsidy_amount
    expected_revenue = param.proba_c .* optimal_revenue_C .+ (1.0 - param.proba_c) .* optimal_revenue_NC
    expected_revenue_output = param.proba_c .* revenue_output_C .+ (1.0 - param.proba_c) .* revenue_output_NC

    return (
        expected_profits = expected_profits,
        expected_subsidies = expected_subsidies,
        expected_revenue = expected_revenue,
        expected_revenue_output = expected_revenue_output,
        subsidy_amount = subsidy_amount
        )
end

# compute subsidies, profits, revenue etc with wedges 
"""
    compute_subsidy_wedges_baseline(; z_star_grid, epsilon_grid, wedges_grid, relative_wedge, wage = 1.0, tax = vat_tax, param)

Solve the rent-seeking choice `m_R` element by element when each entry also
faces an input wedge, and back out the implied capital and labor.

`x_star` is computed from `wage` and `tax` and divided by `wedges_grid`. The
capital and labor wedges are

    capital_wedge = exp(log(wedge) * relative_wedge / alpha_gross) - 1
    labor_wedge   = exp(log(wedge) * (1 - relative_wedge) / beta_gross) - 1

and are then restricted to hard-coded bounds. `m_R` is zero when
`epsilon_grid[i] < cost_level`; otherwise it is the bisection root of the
first-order condition on `(1e-32, upper_bound[i])`.

# Keywords
- `z_star_grid`, `epsilon_grid`, `wedges_grid`: indexed or broadcast together,
  so they must have equal length.
- `relative_wedge`: enters the two wedge formulas above.
- `wage`, `tax`: default to `1.0` and the global `vat_tax`.
- `param`: supplies `connect_theta`, `connect_cost_level`,
  `connect_cost_elasticity` and `connect_fixed_cost`.

# Returns
A named tuple. `optimal_subsidy`, `optimal_m_R`, `optimal_revenue`,
`optimal_profits`, `TFPQ_model`, `optimal_capital`, `optimal_labor` and
`revenue_output_C` reflect the `choose_C` decision. `optimal_profits_C` is the
unmasked `C` profit before the fixed cost.
"""
function compute_subsidy_wedges_baseline(; z_star_grid, epsilon_grid, wedges_grid, relative_wedge, wage = 1.0, tax = vat_tax, param)

    #### Step 1: Preparation ####

    theta = param.connect_theta
    cost_level = param.connect_cost_level
    cost_elast = param.connect_cost_elasticity
    fixed_cost = param.connect_fixed_cost

    # x_star at the given wage and tax, scaled down by the wedge
    alpha_term = ((1-tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross
    beta_term = ((1-tax)*beta_gross/wage)^beta_gross
    gamma_term = gamma_gross^gamma_gross
    x_star_grid = (alpha_term*beta_term*gamma_term) .* (1 ./ wedges_grid)

    z_tilde_grid = (z_star_grid .* x_star_grid).^(1/(1-eta_tilde))

    # get capital and labor wedge from relative wedge and wedges grid
    capital_wedge = exp.(log.(wedges_grid) .* relative_wedge ./ alpha_gross) .- 1.0
    labor_wedge = exp.(log.(wedges_grid) .* (1.0 .- relative_wedge) ./ beta_gross) .- 1.0

    # restrict max and min of capital/labor wedges (enforce 5th and 95th percentile in data)
    capital_wedge = clamp.(capital_wedge, -0.7434, 3.2371)
    labor_wedge = clamp.(labor_wedge, -0.4761, 1.2882)

    # first-order condition for m_R: marginal after-subsidy revenue minus 1
    function rent_seeking_FOC(m_R, z_tilde_value, epsilon)
        return z_tilde_value .* ( (1.0 .+ (epsilon .*((m_R).^(theta)) .- cost_level .* (m_R.^cost_elast))).^elasticity_ratio ) .* ( (theta .* epsilon .* ((m_R).^(theta - 1.0)) .- cost_level .* cost_elast .* (m_R.^(cost_elast-1.0))) ) .- 1
    end

    #### Step 2: Solve for m_R* using FOCs ####

    # bracket: something close to zero up to the m_R that maximizes the subsidy
    # (never want to spend beyond that point)
    lower_bound = 1e-32
    upper_bound = ((theta .* epsilon_grid) ./ (cost_level .* cost_elast)).^(1.0 ./ (cost_elast - theta))

    m_R_star = zeros(length(epsilon_grid))
    for row in eachindex(m_R_star)
        # below the cost level the entry stays at m_R = 0
        epsilon_grid[row] < cost_level && continue
        m_R_star[row] = Roots.find_zero(
            m_R -> rent_seeking_FOC(m_R, z_tilde_grid[row], epsilon_grid[row]),
            (lower_bound, upper_bound[row]), Bisection())
    end

    #### Step 3: Solve for remaining objects needed ####

    subsidy_C = epsilon_grid .* (m_R_star.^theta) .- cost_level .* (m_R_star.^cost_elast)
    revenue_C = z_tilde_grid .* ( (1.0 .+ subsidy_C).^(elasticity_ratio .+ 1.0) )
    optimal_profits_C = (1-profit_tax) .* ((1.0 .- eta_tilde) .* revenue_C .- m_R_star)
    optimal_profits_NC = (1-profit_tax) .* (1.0 .- eta_tilde) .* z_tilde_grid

    # take the subsidy only if it still pays after the (taxed) fixed cost
    net_profits_C = optimal_profits_C .- ((1-profit_tax) .* fixed_cost)
    choose_C = net_profits_C .> optimal_profits_NC

    optimal_subsidy = ifelse.(choose_C, subsidy_C, 0.0)
    optimal_m_R = ifelse.(choose_C, m_R_star, 0.0)
    optimal_revenue = ifelse.(choose_C, revenue_C, z_tilde_grid)
    optimal_profits = ifelse.(choose_C, net_profits_C, optimal_profits_NC)
    TFPQ_model = (1.0 .+ optimal_subsidy) .* z_star_grid

    # input demands implied by the chosen revenue and the (bounded) wedges
    optimal_capital = (alpha_gross .* (1-tax) .* optimal_revenue) ./ ((1.0 .+ capital_wedge) .* (rental_rate - 1.0))
    optimal_labor = (beta_gross .* (1-tax) .* optimal_revenue) ./ ((1.0 .+ labor_wedge) .* wage)

    #### Return final results ####
    return (
        z_star = z_star_grid, epsilon = epsilon_grid, z_tilde = z_tilde_grid,
        choose_C = choose_C, optimal_subsidy = optimal_subsidy, optimal_m_R = optimal_m_R,
        optimal_revenue = optimal_revenue, optimal_profits_C = optimal_profits_C,
        optimal_profits_NC = optimal_profits_NC, optimal_profits = optimal_profits,
        TFPQ_model = TFPQ_model,
        optimal_capital = optimal_capital,
        optimal_labor = optimal_labor,
        revenue_output_C = z_tilde_grid .* ((1.0 .+ optimal_subsidy).^elasticity_ratio),
        revenue_output_NC = z_tilde_grid)
end

# compute profits, revenue, capital and labor with wedges for the NC variant (subsidy fixed at zero)
"""
    compute_subsidy_wedges_baseline_NC(; z_star_grid, wedges_grid, relative_wedge, w = 1.0, tax = vat_tax, param)

Revenue, profits, capital and labor with input wedges when the subsidy is zero
for every entry.

`x_star` is computed from `w` and `tax` and divided by `wedges_grid`. The
capital and labor wedges are

    capital_wedge = exp(log(wedge) * relative_wedge / alpha_gross) - 1
    labor_wedge   = exp(log(wedge) * (1 - relative_wedge) / beta_gross) - 1

and are then restricted to hard-coded bounds.

# Keywords
- `z_star_grid`, `wedges_grid`: combined elementwise.
- `relative_wedge`: enters the two wedge formulas above.
- `w`, `tax`: default to `1.0` and the global `vat_tax`. The same `tax` is used
  for `x_star_grid` and for the capital and labor demands.
- `param`: accepted so existing calls keep working; no field of it is read.

# Returns
A named tuple. `optimal_revenue`, `revenue_output_C` and `revenue_output_NC`
are all `z_tilde`, `optimal_profits_C` equals `optimal_profits_NC`, and
`optimal_subsidy` is a zero vector of length `length(wedges_grid)`.
"""
function compute_subsidy_wedges_baseline_NC(; z_star_grid, wedges_grid, relative_wedge, w = 1.0, tax = vat_tax, param)

    #### Step 1: Preparation ####

    # x_star at the given wage and tax, scaled down by the wedge
    alpha_term = ((1-tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross
    beta_term = ((1-tax)*beta_gross/w)^beta_gross
    gamma_term = gamma_gross^gamma_gross
    x_star_grid = (alpha_term*beta_term*gamma_term) .* (1 ./ wedges_grid)

    z_tilde_grid = (z_star_grid .* x_star_grid).^(1/(1-eta_tilde))

    # get capital and labor wedge from relative wedge and wedges grid
    capital_wedge = exp.(log.(wedges_grid) .* relative_wedge ./ alpha_gross) .- 1.0
    labor_wedge = exp.(log.(wedges_grid) .* (1.0 .- relative_wedge) ./ beta_gross) .- 1.0

    # restrict max and min of capital/labor wedges (enforce 5th and 95th percentile in data)
    # NOTE(review): three of these bounds differ from the ones hard-coded in
    # compute_subsidy_wedges_baseline -- confirm that this is intended.
    capital_wedge = clamp.(capital_wedge, -0.7565, 3.2371)
    labor_wedge = clamp.(labor_wedge, -0.558, 1.288)

    #### Step 2: Solve for remaining objects needed ####

    optimal_subsidy = zeros(length(wedges_grid))

    # with a zero subsidy, revenue is z_tilde and the C and NC profits coincide
    optimal_revenue = z_tilde_grid
    optimal_profits_C = (1-profit_tax) .* (1.0 .- eta_tilde) .* optimal_revenue
    optimal_profits_NC = (1-profit_tax) .* (1.0 .- eta_tilde) .* z_tilde_grid

    # use the `tax` argument (not the global `vat_tax`) so the input demands
    # stay consistent with x_star_grid when a non-default tax is passed
    optimal_capital = (alpha_gross .* (1-tax) .* optimal_revenue) ./ ((1 .+ capital_wedge) .* (rental_rate - 1.0))
    optimal_labor = (beta_gross .* (1-tax) .* optimal_revenue) ./ ((1 .+ labor_wedge) .* w)

    #### Return final results ####
    return (
        z_star = z_star_grid, z_tilde = z_tilde_grid,
        optimal_subsidy = optimal_subsidy,
        optimal_revenue = optimal_revenue, optimal_profits_C = optimal_profits_C,
        optimal_profits_NC = optimal_profits_NC,
        optimal_capital = optimal_capital,
        optimal_labor = optimal_labor,
        capital_wedge = capital_wedge,
        labor_wedge = labor_wedge,
        revenue_output_C = z_tilde_grid,
        revenue_output_NC = z_tilde_grid)
end

# compute expected profits (only variation over z_star) with wedges 
"""
    compute_profits_grid_wedges_baseline(; z_star_grid, n_draws, param)

Simulate expected within-period outcomes on a productivity grid in the model with wedges.

For every entry of `z_star_grid`, `n_draws` values of `epsilon` and of the log wedge are
drawn from normal distributions whose means are linear in `log(z_star)`. Each draw is
weighted by its density (normalised to sum to one). Per-draw outcomes come from
`compute_subsidy_wedges_baseline` (results stored with suffix `_C`) and
`compute_subsidy_wedges_baseline_NC` (suffix `_NC`); the stored expectations are weighted
sums over the draws.

# Keywords
- `z_star_grid`: productivity levels; `log` is applied to every entry.
- `n_draws`: number of simulation draws per grid point.
- `param`: parameter object. Fields read here: `alpha_eps`, `beta_eps`, `variance_eps_z`,
  `mean_tau_C`, `sd_tau_C`, `corr_wedge_z_C`, `mean_tau_NC`, `sd_tau_NC`,
  `corr_wedge_z_NC`, `sd_log_z_star`, `mean_log_z_star` and `proba_c`.

# Returns
A `NamedTuple` of per-grid-point expectations (`expected_*`, length `length(z_star_grid)`)
and of stacked per-draw vectors (`*_variation`, length `length(z_star_grid) * n_draws`,
grid point by grid point). Fields without a `_C`/`_NC` suffix are mixtures with weight
`param.proba_c` on the `_C` value; `expected_subsidies` is `param.proba_c` times
`expected_subsidies_C`.

# Notes
- The global RNG is re-seeded several times per grid point, so results are reproducible
  but the caller's RNG state is overwritten.
- The same seed (12345) precedes the `epsilon`, `wedge_C` and `wedge_NC` draws.
  NOTE(review): this makes the three sets of draws share one random stream; presumably
  intentional -- confirm.
- Reads the globals `capital_wedge_share_C`, `capital_wedge_share_NC` and `gamma_gross`.
"""
function compute_profits_grid_wedges_baseline(; z_star_grid, n_draws, param)

    # length 
    n_rows = length(z_star_grid)

    # create results vectors 
    expected_profits_C = zeros(n_rows)
    expected_profits_NC = zeros(n_rows)
    expected_m_R_C = zeros(n_rows)
    expected_subsidies_C = zeros(n_rows)
    expected_revenue_output_C = zeros(n_rows)
    expected_revenue_output_NC = zeros(n_rows)
    expected_subsidy_rate_C = zeros(n_rows)
    expected_revenue_C = zeros(n_rows)
    expected_revenue_NC = zeros(n_rows)
    expected_labor_C = zeros(n_rows)
    expected_labor_NC = zeros(n_rows)
    expected_capital_C = zeros(n_rows)
    expected_capital_NC = zeros(n_rows)
    expected_capital_wedge = zeros(n_rows)
    expected_labor_wedge = zeros(n_rows)
    subsidy_variation = zeros(n_rows*n_draws)
    epsilon_variation = zeros(n_rows*n_draws)
    epsilon_proba_variation = zeros(n_rows*n_draws)
    rent_seeking_share_variation = zeros(n_rows*n_draws)

    # for each row 
    @inbounds for row = 1:n_rows

        ### Substep 1: Draw epsilon conditional on z_star & the probability 
        Random.seed!(12345)
        normal_distrib_eps = Normal(param.alpha_eps + param.beta_eps*log(z_star_grid[row]), sqrt(param.variance_eps_z))
        epsilon = rand(normal_distrib_eps, n_draws)
        proba_epsilon = pdf.(normal_distrib_eps, epsilon)

        # normalize probability to make sure it sums to 1  
        proba_epsilon = proba_epsilon ./ sum(proba_epsilon)

        ### Substep 2: Draw wedge conditional on z_star & the probability 
        Random.seed!(12345)
        normal_distrib_wedge_C = Normal(
            param.mean_tau_C + param.corr_wedge_z_C*(param.sd_tau_C/param.sd_log_z_star)*(log(z_star_grid[row]) - param.mean_log_z_star), 
            sqrt((1-(param.corr_wedge_z_C^2)))*param.sd_tau_C) 
        wedge_C = exp.(rand(normal_distrib_wedge_C, n_draws))
        proba_wedge_C = pdf.(normal_distrib_wedge_C, log.(wedge_C))

        # row-specific seed for the split of the wedge between capital and labor
        Random.seed!(12345 + row)
        relative_capital_wedges_C = sample(capital_wedge_share_C, n_draws)        
        
        Random.seed!(12345)
        normal_distrib_wedge_NC = Normal(
            param.mean_tau_NC + param.corr_wedge_z_NC*(param.sd_tau_NC/param.sd_log_z_star)*(log(z_star_grid[row]) - param.mean_log_z_star), 
            sqrt((1-(param.corr_wedge_z_NC^2)))*param.sd_tau_NC) 
        wedge_NC = exp.(rand(normal_distrib_wedge_NC, n_draws))
        proba_wedge_NC = pdf.(normal_distrib_wedge_NC, log.(wedge_NC))

        Random.seed!(12346 + row)
        relative_capital_wedges_NC = sample(capital_wedge_share_NC, n_draws) 

        # normalize probability to make sure it sums to 1  
        proba_wedge_C = proba_wedge_C ./ sum(proba_wedge_C)
        proba_wedge_NC = proba_wedge_NC ./ sum(proba_wedge_NC)

        ## Get joint probability 
        proba_joint_C = (proba_wedge_C .* proba_epsilon) ./ sum(proba_wedge_C .* proba_epsilon)

        ### Substep 3: Get optimal profits 
        # (x_star and z_tilde are rebuilt from the wedges inside the two solvers below)

        results_C = compute_subsidy_wedges_baseline(
            z_star_grid = ones(n_draws) .* z_star_grid[row], 
            epsilon_grid = epsilon, 
            wedges_grid = wedge_C, 
            relative_wedge = relative_capital_wedges_C, 
            param = param) 

        # Version without epsilon, but with wedges!
        results_NC = compute_subsidy_wedges_baseline_NC(
                z_star_grid = ones(n_draws) .* z_star_grid[row], 
                wedges_grid = wedge_NC, 
                relative_wedge = relative_capital_wedges_NC,
                param = param)

        # get expected profits by taking weighted mean across all possible epsilon 
        expected_profits_C[row] = sum(results_C.optimal_profits_C .* proba_joint_C)
        expected_profits_NC[row] = sum(results_NC.optimal_profits_NC .* proba_wedge_NC)
  
        # collect further objects needed for later 
        expected_m_R_C[row] = sum(results_C.optimal_m_R .* proba_joint_C)
        expected_subsidies_C[row] = sum(results_C.optimal_revenue .* (results_C.optimal_subsidy ./ (1.0 .+ results_C.optimal_subsidy)) .* proba_joint_C)
        expected_revenue_output_C[row] = sum(results_C.revenue_output_C .* proba_joint_C)
        expected_revenue_output_NC[row] = sum(results_NC.revenue_output_NC .* proba_wedge_NC)
        expected_subsidy_rate_C[row] = sum(results_C.optimal_subsidy .* proba_joint_C)
        expected_revenue_C[row] = sum(results_C.optimal_revenue .* proba_joint_C)
        expected_revenue_NC[row] = sum(results_NC.optimal_revenue .* proba_wedge_NC)

        # save expected labor and capital 
        expected_labor_C[row] = sum(results_C.optimal_labor .* proba_joint_C)
        expected_labor_NC[row] = sum(results_NC.optimal_labor .* proba_wedge_NC)
        expected_capital_C[row] = sum(results_C.optimal_capital .* proba_joint_C)
        expected_capital_NC[row] = sum(results_NC.optimal_capital .* proba_wedge_NC)

        # wedge expectations are taken from the NC results only
        expected_capital_wedge[row] = sum(results_NC.capital_wedge .* proba_wedge_NC)
        expected_labor_wedge[row] = sum(results_NC.labor_wedge .* proba_wedge_NC)

        # also save all subsidy variation (positions of this grid point's draws in the stacked vectors)
        draw_slots = ((row-1)*n_draws + 1):(row*n_draws)
        subsidy_variation[draw_slots] = results_C.optimal_subsidy
        epsilon_variation[draw_slots] = epsilon
        epsilon_proba_variation[draw_slots] = proba_epsilon
        rent_seeking_share_variation[draw_slots] = results_C.optimal_m_R ./ (results_C.optimal_m_R .+ (gamma_gross .* results_C.optimal_revenue))
    end 

    expected_profits = param.proba_c .* expected_profits_C .+ (1.0 - param.proba_c) .* expected_profits_NC

    # adjust all variables by proba_c ? 
    expected_subsidies = param.proba_c .* expected_subsidies_C
    expected_revenue_output = param.proba_c .* expected_revenue_output_C .+ (1.0 - param.proba_c) .* expected_revenue_output_NC
    expected_revenue = param.proba_c .* expected_revenue_C .+ (1.0 - param.proba_c) .* expected_revenue_NC  
    
    # save expected labor and capital
    expected_labor = param.proba_c .* expected_labor_C .+ (1.0 - param.proba_c) .* expected_labor_NC
    expected_capital = param.proba_c .* expected_capital_C .+ (1.0 - param.proba_c) .* expected_capital_NC

    return (
        expected_profits_C = expected_profits_C,
        expected_profits_NC = expected_profits_NC,
        expected_profits = expected_profits, 
        expected_m_R_C = expected_m_R_C,
        expected_subsidies_C = expected_subsidies_C, 
        expected_revenue_output_C = expected_revenue_output_C,
        expected_revenue_output_NC = expected_revenue_output_NC,
        expected_revenue_output = expected_revenue_output, 
        expected_subsidy_rate_C = expected_subsidy_rate_C,
        expected_revenue_C = expected_revenue_C,
        expected_revenue_NC = expected_revenue_NC, 
        expected_revenue = expected_revenue,
        subsidy_variation = subsidy_variation, 
        epsilon_variation = epsilon_variation, 
        epsilon_proba_variation = epsilon_proba_variation,
        rent_seeking_share_variation = rent_seeking_share_variation,
        expected_labor = expected_labor,
        expected_labor_C = expected_labor_C,
        expected_capital = expected_capital,
        expected_capital_C = expected_capital_C,
        expected_labor_wedge = expected_labor_wedge,
        expected_capital_wedge = expected_capital_wedge,
        expected_subsidies = expected_subsidies
        )
end

# compute GE aggregates 
"""
    compute_aggregates(; w, distrib, mass_firms, mass_entry, exit_proba,
                       within_period_choices, VF_results, tax = nothing,
                       entry_cost = entry_cost, param)

Aggregate firm-level expectations into economy-wide totals, with entry costs counted in
labor units.

# Keywords
- `w`: wage.
- `distrib`: weights multiplied element-wise with the fields of `within_period_choices`.
- `mass_firms`: scalar multiplying every weighted sum.
- `mass_entry`: multiplied element-wise with `entry_cost`.
- `exit_proba`: used only for `total_fixed_costs`, through `(1 - exit_proba)`.
- `within_period_choices`: must provide `expected_revenue`, `expected_revenue_output`,
  `expected_profits`, `expected_m_R_C` and `expected_subsidies_C`.
- `VF_results`: must provide `exp_fcost`.
- `tax`: VAT rate; the global `vat_tax` is used when `nothing`.
- `entry_cost`: defaults to the global `entry_cost`.
- `param`: must provide `proba_c`.

# Returns
A `NamedTuple` with input totals (`productive_capital`, `productive_labor`,
`productive_intermediates`), `total_output`, `total_entry_costs`, `total_labor_demand`
(productive labor plus entry costs), `total_rent_seeking`, `gross_tax_revenues`,
`total_subsidies`, `net_govt_transfers`, `total_net_profits`, `total_fixed_costs`,
`total_HH_income`, `total_HH_income_noprofits` and `total_HH_consumption_closed`.
"""
function compute_aggregates(; w, distrib, mass_firms, mass_entry, exit_proba, within_period_choices, VF_results, tax = nothing, entry_cost = entry_cost, param) 

    # fall back to the global VAT rate; identity comparison is the reliable test for `nothing`
    if tax === nothing
        tax = vat_tax
    end 

    # Get total (productive) inputs 
    productive_capital = sum( within_period_choices.expected_revenue .* alpha_gross .* (1-tax) .* distrib ./ (rental_rate - 1) ) * mass_firms
    productive_labor = sum( within_period_choices.expected_revenue .* beta_gross .* (1-tax) .* distrib ./ w ) * mass_firms
    productive_intermediates = sum( within_period_choices.expected_revenue .* gamma_gross .* distrib ) * mass_firms

    # Get total output
    total_output = sum( within_period_choices.expected_revenue_output .* distrib ) * mass_firms

    # Get total entry costs (in labor units)
    total_entry_costs = sum( entry_cost .* mass_entry )

    # Can back out total labor demand then 
    total_labor_demand = productive_labor + total_entry_costs

    # Get total rent-seeking activities 
    total_rent_seeking = param.proba_c * sum( within_period_choices.expected_m_R_C .* distrib ) * mass_firms

    # HH consumption under closed economy 
    # (enforce aggregate resource constraint -- only true if closed economy or BOP = 0) 
    total_HH_consumption_closed = total_output - delta*productive_capital - productive_intermediates - total_rent_seeking

    # total tax revenues and subsidies
    # expected_profits is multiplied by profit_tax/(1-profit_tax) to obtain the profit-tax bill
    cit_revenue = sum( within_period_choices.expected_profits .* distrib) * (profit_tax/(1-profit_tax)) * mass_firms
    # the VAT base is revenue net of intermediates and of rent-seeking spending
    vat_revenue = sum( tax .* ((1-gamma_gross) .* within_period_choices.expected_revenue .- param.proba_c .* within_period_choices.expected_m_R_C) .* distrib) * mass_firms
    gross_tax_revenues = cit_revenue + vat_revenue
    total_subsidies = param.proba_c * sum( within_period_choices.expected_subsidies_C .* distrib ) * mass_firms
    net_govt_transfers = gross_tax_revenues - total_subsidies     

    # compute total HH income (not including capital)
    total_net_profits = sum( within_period_choices.expected_profits .* distrib ) * mass_firms
    total_HH_income = net_govt_transfers + total_net_profits + w*total_labor_demand
    total_HH_income_noprofits = net_govt_transfers + w*total_labor_demand

    # Also compute total paid fixed costs 
    total_fixed_costs = sum( (1.0 .- exit_proba) .* VF_results.exp_fcost .* distrib ) * mass_firms

    ### return objects ### 
    return (
        productive_capital = productive_capital, productive_labor = productive_labor, productive_intermediates = productive_intermediates, 
        total_output = total_output, total_entry_costs = total_entry_costs, total_labor_demand = total_labor_demand,
        total_rent_seeking = total_rent_seeking, gross_tax_revenues = gross_tax_revenues, total_subsidies = total_subsidies,
        total_net_profits = total_net_profits,  total_HH_consumption_closed = total_HH_consumption_closed, 
        net_govt_transfers = net_govt_transfers, total_fixed_costs = total_fixed_costs, 
        total_HH_income = total_HH_income, total_HH_income_noprofits = total_HH_income_noprofits)
end 

# compute GE aggregates with entry costs in goods  
"""
    compute_aggregates_goods(; w, distrib, mass_firms, mass_entry, within_period_choices,
                             tax = nothing, entry_cost = entry_cost, param)

Aggregate firm-level expectations into economy-wide totals when entry costs are paid in
goods: they are scaled by the wage, subtracted from profits, and excluded from labor
demand.

# Keywords
- `w`: wage.
- `distrib`: weights multiplied element-wise with the fields of `within_period_choices`.
- `mass_firms`: scalar multiplying every weighted sum.
- `mass_entry`: multiplied element-wise with `entry_cost`.
- `within_period_choices`: must provide `expected_revenue`, `expected_revenue_output`,
  `expected_profits`, `expected_m_R_C` and `expected_subsidies_C`.
- `tax`: VAT rate; the global `vat_tax` is used when `nothing`.
- `entry_cost`: defaults to the global `entry_cost`.
- `param`: must provide `proba_c`.

# Returns
A `NamedTuple` with `productive_capital`, `productive_labor`, `productive_intermediates`,
`total_output`, `total_entry_costs`, `total_labor_demand` (productive labor only),
`total_rent_seeking`, `gross_tax_revenues`, `total_subsidies`, `total_net_profits`
(net of entry costs), `total_HH_consumption_closed`, `net_govt_transfers`,
`total_HH_income` and `total_HH_income_noprofits`.
"""
function compute_aggregates_goods(; w, distrib, mass_firms, mass_entry, within_period_choices, tax = nothing, entry_cost = entry_cost, param) 

    # fall back to the global VAT rate; identity comparison is the reliable test for `nothing`
    if tax === nothing
        tax = vat_tax
    end 

    # Get total (productive) inputs 
    productive_capital = sum( within_period_choices.expected_revenue .* alpha_gross .* (1-tax) .* distrib ./ (rental_rate - 1) ) * mass_firms
    productive_labor = sum( within_period_choices.expected_revenue .* beta_gross .* (1-tax) .* distrib ./ w ) * mass_firms
    productive_intermediates = sum( within_period_choices.expected_revenue .* gamma_gross .* distrib ) * mass_firms

    # Get total output
    total_output = sum( within_period_choices.expected_revenue_output .* distrib ) * mass_firms

    # Get total entry costs (scale with wage as explained in Appendix B)
    # NOTE(review): unlike `compute_aggregates`, this sum is also multiplied by `mass_firms` -- confirm.
    total_entry_costs = w * sum( entry_cost .* mass_entry ) * mass_firms 
    
    # Get total profits (subtracting entry costs)
    total_net_profits = sum( within_period_choices.expected_profits .* distrib ) * mass_firms - total_entry_costs
    
    # Can back out total labor supply then (only productive labor now, not entry costs!)
    total_labor_demand = productive_labor

    # Get total rent-seeking activities 
    total_rent_seeking = param.proba_c * sum( within_period_choices.expected_m_R_C .* distrib ) * mass_firms

    # HH consumption under closed economy 
    # (enforce aggregate resource constraint -- only true if closed economy or BOP = 0) 
    # NOTE(review): `total_entry_costs` is not subtracted here -- confirm this is intended.
    total_HH_consumption_closed = total_output - delta*productive_capital - productive_intermediates - total_rent_seeking

    # total tax revenues and subsidies
    # the profit-tax base uses expected_profits before entry costs are subtracted
    cit_revenue = sum( within_period_choices.expected_profits .* distrib) * (profit_tax/(1-profit_tax)) * mass_firms
    vat_revenue = sum( tax .* ((1-gamma_gross) .* within_period_choices.expected_revenue .- param.proba_c .* within_period_choices.expected_m_R_C) .* distrib) * mass_firms
    gross_tax_revenues = cit_revenue + vat_revenue
    total_subsidies = param.proba_c * sum( within_period_choices.expected_subsidies_C .* distrib ) * mass_firms
    net_govt_transfers = gross_tax_revenues - total_subsidies     

    # compute total HH income (not including capital)
    total_HH_income = net_govt_transfers + total_net_profits + w*total_labor_demand
    total_HH_income_noprofits = net_govt_transfers + w*total_labor_demand

    ### return objects ### 
    return (
        productive_capital = productive_capital, productive_labor = productive_labor, productive_intermediates = productive_intermediates, 
        total_output = total_output, total_entry_costs = total_entry_costs, total_labor_demand = total_labor_demand,
        total_rent_seeking = total_rent_seeking, gross_tax_revenues = gross_tax_revenues, total_subsidies = total_subsidies,
        total_net_profits = total_net_profits,  total_HH_consumption_closed = total_HH_consumption_closed, 
        net_govt_transfers = net_govt_transfers,
        total_HH_income = total_HH_income, total_HH_income_noprofits = total_HH_income_noprofits)
end 

# compute GE aggregates without rent-seeking or subsidy terms (entry costs in labor units)
"""
    compute_aggregates_noC(; w, distrib, mass_firms, mass_entry, exit_proba,
                           within_period_choices, VF_results, tax = nothing, param,
                           entry_cost = entry_cost)

Aggregate firm-level expectations into economy-wide totals without any rent-seeking or
subsidy terms, with entry costs counted in labor units.

# Keywords
- `w`: wage.
- `distrib`: weights multiplied element-wise with the fields of `within_period_choices`.
- `mass_firms`: scalar multiplying every weighted sum.
- `mass_entry`: multiplied element-wise with `entry_cost`.
- `exit_proba`: used only for `total_fixed_costs`, through `(1 - exit_proba)`.
- `within_period_choices`: must provide `expected_revenue`, `expected_revenue_output`
  and `expected_profits`.
- `VF_results`: must provide `exp_fcost`.
- `tax`: VAT rate; the global `vat_tax` is used when `nothing`.
- `param`: required keyword, but not read in this function.
- `entry_cost`: defaults to the global `entry_cost`.

# Returns
A `NamedTuple` with `productive_capital`, `productive_labor`, `productive_intermediates`,
`total_output`, `total_entry_costs`, `total_labor_demand` (productive labor plus entry
costs), `tax_revenues`, `total_net_profits`, `total_fixed_costs`, `total_HH_income`,
`total_HH_income_noprofits` and `total_HH_consumption_closed`.
"""
function compute_aggregates_noC(; w, distrib, mass_firms, mass_entry, exit_proba, within_period_choices, VF_results, tax = nothing, param, entry_cost = entry_cost) 

    # fall back to the global VAT rate; identity comparison is the reliable test for `nothing`
    if tax === nothing
        tax = vat_tax
    end 

    # Get total (productive) inputs 
    productive_capital = sum( within_period_choices.expected_revenue .* alpha_gross .* (1-tax) .* distrib ./ (rental_rate - 1) ) * mass_firms
    productive_labor = sum( within_period_choices.expected_revenue .* beta_gross .* (1-tax) .* distrib ./ w ) * mass_firms
    productive_intermediates = sum( within_period_choices.expected_revenue .* gamma_gross .* distrib ) * mass_firms

    # Get total output
    total_output = sum( within_period_choices.expected_revenue_output .* distrib ) * mass_firms
    total_net_profits = sum( within_period_choices.expected_profits .* distrib ) * mass_firms

    # Get total entry costs (in labor units)
    total_entry_costs = sum( entry_cost .* mass_entry )

    # Can back out total labor supply then 
    total_labor_demand = productive_labor + total_entry_costs

    # total tax revenues
    # expected_profits is multiplied by profit_tax/(1-profit_tax) to obtain the profit-tax bill
    cit_revenue = sum( within_period_choices.expected_profits .* distrib) * (profit_tax/(1-profit_tax)) * mass_firms
    vat_revenue = sum( tax .* (1-gamma_gross) .* within_period_choices.expected_revenue .* distrib) * mass_firms
    tax_revenues = cit_revenue + vat_revenue

    # compute total HH income (not including capital)
    total_HH_income = tax_revenues + total_net_profits + w*total_labor_demand
    total_HH_income_noprofits = tax_revenues + w*total_labor_demand

    # closed-economy HH consumption 
    total_HH_consumption_closed = total_output - delta*productive_capital - productive_intermediates

    # Also compute total paid fixed costs 
    total_fixed_costs = sum( (1.0 .- exit_proba) .* VF_results.exp_fcost .* distrib ) * mass_firms

    ### return objects ### 
    return (
        productive_capital = productive_capital, productive_labor = productive_labor, productive_intermediates = productive_intermediates, 
        total_output = total_output, total_entry_costs = total_entry_costs, total_labor_demand = total_labor_demand,
        tax_revenues = tax_revenues, total_net_profits = total_net_profits, total_fixed_costs = total_fixed_costs, 
        total_HH_income = total_HH_income, total_HH_income_noprofits = total_HH_income_noprofits, 
        total_HH_consumption_closed = total_HH_consumption_closed)
end 

# with entry costs in goods 
"""
    compute_aggregates_noC_goods(; w, distrib, mass_firms, mass_entry, exit_proba,
                                 within_period_choices, VF_results, tax = nothing, param,
                                 entry_cost = entry_cost)

Aggregate firm-level expectations into economy-wide totals without any rent-seeking or
subsidy terms, when entry costs are paid in goods: they are scaled by the wage,
subtracted from profits, and excluded from labor demand.

# Keywords
- `w`: wage.
- `distrib`: weights multiplied element-wise with the fields of `within_period_choices`.
- `mass_firms`: scalar multiplying every weighted sum.
- `mass_entry`: multiplied element-wise with `entry_cost`.
- `exit_proba`: used only for `total_fixed_costs`, through `(1 - exit_proba)`.
- `within_period_choices`: must provide `expected_revenue`, `expected_revenue_output`
  and `expected_profits`.
- `VF_results`: must provide `exp_fcost`.
- `tax`: VAT rate; the global `vat_tax` is used when `nothing`.
- `param`: required keyword, but not read in this function.
- `entry_cost`: defaults to the global `entry_cost`.

# Returns
A `NamedTuple` with `productive_capital`, `productive_labor`, `productive_intermediates`,
`total_output`, `total_entry_costs`, `total_labor_demand` (productive labor only),
`tax_revenues`, `total_net_profits` (net of entry costs), `total_fixed_costs`,
`total_HH_income`, `total_HH_income_noprofits` and `total_HH_consumption_closed`.
"""
function compute_aggregates_noC_goods(; w, distrib, mass_firms, mass_entry, exit_proba, within_period_choices, VF_results, tax = nothing, param, entry_cost = entry_cost) 

    # fall back to the global VAT rate; identity comparison is the reliable test for `nothing`
    if tax === nothing
        tax = vat_tax
    end 

    # Get total (productive) inputs 
    productive_capital = sum( within_period_choices.expected_revenue .* alpha_gross .* (1-tax) .* distrib ./ (rental_rate - 1) ) * mass_firms
    productive_labor = sum( within_period_choices.expected_revenue .* beta_gross .* (1-tax) .* distrib ./ w ) * mass_firms
    productive_intermediates = sum( within_period_choices.expected_revenue .* gamma_gross .* distrib ) * mass_firms

    # Get total entry costs (scale with wage as explained in Appendix B)
    # NOTE(review): unlike `compute_aggregates_noC`, this sum is also multiplied by `mass_firms` -- confirm.
    total_entry_costs = w * sum( entry_cost .* mass_entry ) * mass_firms 

    # Get total output
    total_output = sum( within_period_choices.expected_revenue_output .* distrib ) * mass_firms
    total_net_profits = sum( within_period_choices.expected_profits .* distrib ) * mass_firms - total_entry_costs
    
    # Can back out total labor supply then (only productive labor now, not entry costs!)
    total_labor_demand = productive_labor

    # total tax revenues
    # the profit-tax base uses expected_profits before entry costs are subtracted
    cit_revenue = sum( within_period_choices.expected_profits .* distrib) * (profit_tax/(1-profit_tax)) * mass_firms
    vat_revenue = sum( tax .* (1-gamma_gross) .* within_period_choices.expected_revenue .* distrib) * mass_firms
    tax_revenues = cit_revenue + vat_revenue

    # compute total HH income (not including capital)
    total_HH_income = tax_revenues + total_net_profits + w*total_labor_demand
    total_HH_income_noprofits = tax_revenues + w*total_labor_demand

    # closed-economy HH consumption 
    # NOTE(review): `total_entry_costs` is not subtracted here -- confirm this is intended.
    total_HH_consumption_closed = total_output - delta*productive_capital - productive_intermediates

    # Also compute total paid fixed costs 
    total_fixed_costs = sum( (1.0 .- exit_proba) .* VF_results.exp_fcost .* distrib ) * mass_firms

    ### return objects ### 
    return (
        productive_capital = productive_capital, productive_labor = productive_labor, productive_intermediates = productive_intermediates, 
        total_output = total_output, total_entry_costs = total_entry_costs, total_labor_demand = total_labor_demand,
        tax_revenues = tax_revenues, total_net_profits = total_net_profits, total_fixed_costs = total_fixed_costs, 
        total_HH_income = total_HH_income, total_HH_income_noprofits = total_HH_income_noprofits, 
        total_HH_consumption_closed = total_HH_consumption_closed)
end 

"""
    compute_aggregates_noC_goods_noEE(; w, distrib, mass_firms, within_period_choices,
                                      tax = nothing)

Aggregate firm-level expectations into economy-wide totals without rent-seeking, subsidy,
entry-cost or fixed-cost terms.

# Keywords
- `w`: wage.
- `distrib`: weights multiplied element-wise with the fields of `within_period_choices`.
- `mass_firms`: scalar multiplying every weighted sum.
- `within_period_choices`: must provide `expected_revenue`, `expected_revenue_output`
  and `expected_profits`.
- `tax`: VAT rate; the global `vat_tax` is used when `nothing`.

# Returns
A `NamedTuple` with `productive_capital`, `productive_labor`, `productive_intermediates`,
`total_output`, `total_labor_demand` (equal to `productive_labor`), `tax_revenues`,
`total_net_profits`, `total_HH_income`, `total_HH_income_noprofits` and
`total_HH_consumption_closed`.
"""
function compute_aggregates_noC_goods_noEE(; w, distrib, mass_firms, within_period_choices, tax = nothing) 

    # fall back to the global VAT rate; identity comparison is the reliable test for `nothing`
    if tax === nothing
        tax = vat_tax
    end 

    # Get total (productive) inputs 
    productive_capital = sum( within_period_choices.expected_revenue .* alpha_gross .* (1-tax) .* distrib ./ (rental_rate - 1) ) * mass_firms
    productive_labor = sum( within_period_choices.expected_revenue .* beta_gross .* (1-tax) .* distrib ./ w ) * mass_firms
    productive_intermediates = sum( within_period_choices.expected_revenue .* gamma_gross .* distrib ) * mass_firms

    # Get total output
    total_output = sum( within_period_choices.expected_revenue_output .* distrib ) * mass_firms
    total_net_profits = sum( within_period_choices.expected_profits .* distrib ) * mass_firms
    
    # Can back out total labor supply then (only productive labor now, not entry costs!)
    total_labor_demand = productive_labor

    # total tax revenues
    # expected_profits is multiplied by profit_tax/(1-profit_tax) to obtain the profit-tax bill
    cit_revenue = sum( within_period_choices.expected_profits .* distrib) * (profit_tax/(1-profit_tax)) * mass_firms
    vat_revenue = sum( tax .* (1-gamma_gross) .* within_period_choices.expected_revenue .* distrib) * mass_firms
    tax_revenues = cit_revenue + vat_revenue

    # compute total HH income (not including capital)
    total_HH_income = tax_revenues + total_net_profits + w*total_labor_demand
    total_HH_income_noprofits = tax_revenues + w*total_labor_demand

    # closed-economy HH consumption 
    total_HH_consumption_closed = total_output - delta*productive_capital - productive_intermediates

    ### return objects ### 
    return (
        productive_capital = productive_capital, productive_labor = productive_labor, productive_intermediates = productive_intermediates, 
        total_output = total_output, total_labor_demand = total_labor_demand,
        tax_revenues = tax_revenues, total_net_profits = total_net_profits, 
        total_HH_income = total_HH_income, total_HH_income_noprofits = total_HH_income_noprofits, 
        total_HH_consumption_closed = total_HH_consumption_closed)
end 

# solve HH consumption/savings path 
"""
    compute_HH_transition_path(; net_income_path, C_low, C_high, initial_assets,
                               length_transition = 1000, crit = 1e-6, max_iter = 500)

Find a constant consumption level by bisection (shooting) on the asset accumulation
equation `A[t] = net_income_path[t-1] + (1+r)*A[t-1] - C`, where
`r = rental_rate - delta - 1.0` is built from globals.

A guess is judged by the last step of the simulated asset path: if assets are still
rising the guess becomes the new lower bound, if they are falling it becomes the new
upper bound. The search stops when the bracket is narrower than `crit`, when `max_iter`
is reached, or when the last step of the path is exactly flat.

# Keywords
- `net_income_path`: income per period; entries `1:(length_transition - 1)` are read.
- `C_low`, `C_high`: initial bracket for consumption.
- `initial_assets`: value of `A[1]`.
- `length_transition`: number of simulated periods (at least 2).
- `crit`: tolerance on `C_high - C_low`.
- `max_iter`: iteration cap; the loop runs while `iter < max_iter`, starting at `iter = 1`.

# Returns
`(A_path = ..., C = ...)`. `A_path` is the path simulated with the last guess that was
tested; `C` is the midpoint of the final bracket, which has not itself been simulated.

# Throws
- `ArgumentError` if `length_transition < 2`.
- `DimensionMismatch` if `net_income_path` has fewer than `length_transition - 1` entries.
"""
function compute_HH_transition_path(; net_income_path, C_low, C_high, initial_assets, length_transition = 1000, crit = 1e-6, max_iter = 500)

    ### Need to solve for C & A_t path (shooting algorithm!)  

    # validate up front: the simulation loop below runs under @inbounds
    length_transition >= 2 || throw(ArgumentError("length_transition must be at least 2, got $length_transition"))
    length(net_income_path) >= length_transition - 1 || throw(DimensionMismatch(
        "net_income_path has $(length(net_income_path)) entries but $(length_transition - 1) are required"))
    
    # initialize objects 
    r = rental_rate - delta - 1.0
    A_path = ones(length_transition).*initial_assets 
    diff = Inf 
    iter = 1
    C_guess = (C_low + C_high)/2 

    # set shooting algorithm 
    while diff > crit && iter < max_iter

        # solve for A_path given C_guess 
        @inbounds for year = 2:length_transition
            A_path[year] = net_income_path[year - 1] + (1+r)*A_path[year-1] - C_guess 
        end 

        # update C_low or C_high (bisection) 
        println("Iter: ",iter," with C_low: ", C_low, ", C_high: ", C_high, ", Difference: ", C_high - C_low)
        if A_path[length_transition] > A_path[length_transition - 1] # exploding (so: too much savings)
            C_low = C_guess
        elseif A_path[length_transition] < A_path[length_transition - 1] # shrinking (so: too little savings)
            C_high = C_guess
        else
            # flat (or not comparable, e.g. NaN): neither bound would move, so every further
            # iteration would repeat this one -- stop here with the current guess
            break
        end 

        # update guess 
        C_guess = (C_low + C_high)/2 

        # compute new difference 
        diff = C_high - C_low
        iter = iter + 1 
    end 

    # make a silent failure visible when the iteration cap was hit before the bracket closed
    if diff > crit && iter >= max_iter
        @warn "compute_HH_transition_path stopped at max_iter without reaching crit" max_iter crit diff
    end

    return (A_path = A_path, C = C_guess)
end 

# compute subsidy for DRS model 
"""
    compute_subsidy_baseline_DRS(; z_star_grid, epsilon_grid, param)

Solve, draw by draw, for the rent-seeking choice `m_R` and the implied subsidy, revenue
and profits in the DRS model (with `w = 1.0` imposed in `x_star`).

For every entry with `epsilon > 0`, `m_R` is the root of the first-order condition found
by bisection on `(1e-32, 1e+32)`; entries with `epsilon <= 0` get `m_R = 0`. The firm
takes up the subsidy (`choose_C`) when profits with rent-seeking, net of the after-tax
fixed cost, are at least the profits without it.

# Keywords
- `z_star_grid`: productivity per draw; must have as many entries as `epsilon_grid`.
- `epsilon_grid`: vector of `epsilon` draws.
- `param`: must provide `connect_theta` and `connect_fixed_cost`.

# Returns
A `NamedTuple` with `z_star`, `epsilon`, `z_tilde`, `choose_C`, and the post-choice
`optimal_subsidy`, `optimal_m_R`, `optimal_revenue`, `optimal_profits`, `TFPQ_model`,
`revenue_output_C` (these are set to their no-rent-seeking values where `choose_C` is
false), together with `optimal_profits_C` (before the fixed cost, not conditioned on
`choose_C`), `optimal_profits_NC` and `revenue_output_NC`.

# Throws
- `DimensionMismatch` if `z_star_grid` is not a scalar and its length differs from
  `length(epsilon_grid)`.
- Whatever `Roots.find_zero` throws when the bracket does not contain a sign change.
"""
function compute_subsidy_baseline_DRS(; z_star_grid, epsilon_grid, param)

    #### Step 1: Preparation ####

    # the root-finding loop below indexes both grids under @inbounds, so check lengths first
    if !(z_star_grid isa Number) && length(z_star_grid) != length(epsilon_grid)
        throw(DimensionMismatch(
            "z_star_grid has $(length(z_star_grid)) entries but epsilon_grid has $(length(epsilon_grid))"))
    end

    # get parameters 
    theta = param.connect_theta 
    fixed_cost = param.connect_fixed_cost 

    # baseline x_star (enforcing w = 1.0) 
    x_star = ( (((1-vat_tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross)*(((1-vat_tax)*beta_gross)^beta_gross)*((gamma_gross)^gamma_gross) )

    # get z_tilde_grid 
    z_tilde_grid = (z_star_grid .* x_star).^(1/(1-eta_tilde))
  
    # specify FOC 
    function rent_seeking_FOC(m_R, z_tilde_value, epsilon)
        return z_tilde_value .* ( (1.0 .+ (epsilon .*((m_R).^(theta)))).^elasticity_ratio ) .* theta .* epsilon .* ((m_R).^(theta - 1.0)) .- 1 
    end 

    # for lower bound of root finding: use something close to zero 
    lower_bound = 1e-32 

    # pick high number for upper bound 
    upper_bound = 1e+32
    
    #### Step 2: Solve for m_R* using FOCs ####

    # save results 
    optimal_m_R = zeros(length(epsilon_grid))

    # find root 
    @inbounds for row in eachindex(optimal_m_R)
        if epsilon_grid[row] <= 0.0 
            # no root is searched for non-positive epsilon
            optimal_m_R[row] = 0.0 
        else 
            optimal_m_R[row] = Roots.find_zero(m_R -> rent_seeking_FOC(m_R,z_tilde_grid[row], epsilon_grid[row]), (lower_bound,upper_bound), Bisection())
        end 
    end 

    #### Step 3: Solve for remaining objects needed ####

    optimal_subsidy = epsilon_grid .* (optimal_m_R.^theta)
    optimal_revenue = z_tilde_grid .* ( (1.0 .+ optimal_subsidy ).^(elasticity_ratio .+ 1.0) )
    optimal_profits_C = (1-profit_tax) .* ((1.0 .- eta_tilde) .* optimal_revenue .- optimal_m_R)
    optimal_profits_NC = (1-profit_tax) .* (1.0 .- eta_tilde) .* z_tilde_grid

    # take-up decision: rent-seeking profits net of the after-tax fixed cost vs. no rent-seeking
    choose_C = (optimal_profits_C .- ((1-profit_tax) .* fixed_cost)) .>= optimal_profits_NC

    # where the firm does not take up, fall back to the no-rent-seeking values
    optimal_subsidy = ifelse.(choose_C,optimal_subsidy,0.0)
    optimal_m_R = ifelse.(choose_C, optimal_m_R, 0.0)
    optimal_revenue = ifelse.(choose_C, optimal_revenue, z_tilde_grid)
    optimal_profits = ifelse.(choose_C, optimal_profits_C .- ((1-profit_tax).* fixed_cost), optimal_profits_NC)

    #### Return final results ####
    # TFPQ_model and revenue_output_C use the post-choice subsidy
    return (
        z_star = z_star_grid, epsilon = epsilon_grid, z_tilde = z_tilde_grid,
        choose_C = choose_C, optimal_subsidy = optimal_subsidy, optimal_m_R = optimal_m_R,
        optimal_revenue = optimal_revenue, optimal_profits_C = optimal_profits_C, 
        optimal_profits_NC = optimal_profits_NC, optimal_profits = optimal_profits, 
        TFPQ_model = (1.0 .+ optimal_subsidy) .* z_star_grid, 
        revenue_output_C = z_tilde_grid .* ((1.0 .+ optimal_subsidy).^elasticity_ratio),
        revenue_output_NC = z_tilde_grid)
end 

# compute profits for DRS model 
"""
    compute_profits_grid_baseline_DRS(; z_star_grid, n_eps, param)

Simulate expected within-period outcomes on a productivity grid for the DRS model.

For every entry of `z_star_grid`, `n_eps` values of `epsilon` are drawn from a normal
distribution with mean `param.alpha_eps + param.beta_eps*log(z_star)` and variance
`param.variance_eps_z`. Each draw is weighted by its density (normalised to sum to one),
the per-draw outcomes come from `compute_subsidy_baseline_DRS`, and weighted sums are
stored per grid point.

# Keywords
- `z_star_grid`: productivity levels; `log` is applied to every entry.
- `n_eps`: number of `epsilon` draws per grid point.
- `param`: must provide `alpha_eps`, `beta_eps`, `variance_eps_z` and `proba_c`, plus
  whatever `compute_subsidy_baseline_DRS` reads.

# Returns
A `NamedTuple` of per-grid-point expectations (`expected_*`) and of stacked per-draw
vectors (`*_variation`, length `length(z_star_grid) * n_eps`). Fields without a
`_C`/`_NC` suffix are mixtures with weight `param.proba_c` on the `_C` value.

# Notes
The global RNG is re-seeded with 12345 for every grid point, so results are reproducible
but the caller's RNG state is overwritten.
"""
function compute_profits_grid_baseline_DRS(; z_star_grid, n_eps, param)

    # input-price bundle at w = 1.0, common to all grid points
    x_star = ( (((1-vat_tax)*alpha_gross/(rental_rate - 1.0))^alpha_gross)*(((1-vat_tax)*beta_gross)^beta_gross)*((gamma_gross)^gamma_gross) )

    # revenue scale without any subsidy
    z_tilde = (z_star_grid .* x_star).^(1/(1-eta_tilde))
    n_grid = length(z_tilde)

    # per-grid-point expectations
    profits_C = zeros(n_grid)
    profits_NC = zeros(n_grid)
    m_R_C = zeros(n_grid)
    subsidies_C = zeros(n_grid)
    revenue_output_C = zeros(n_grid)
    revenue_output_NC = zeros(n_grid)
    subsidy_rate_C = zeros(n_grid)
    revenue_C = zeros(n_grid)
    revenue_NC = copy(z_tilde)

    # stacked per-draw records, one slice of n_eps entries per grid point
    n_total = n_grid*n_eps
    all_subsidies = zeros(n_total)
    all_epsilon = zeros(n_total)
    all_weights = zeros(n_total)
    all_rent_shares = zeros(n_total)

    @inbounds for row = 1:n_grid

        z_here = z_star_grid[row]

        # draw epsilon conditional on z_star (fixed seed => same stream at every grid point)
        Random.seed!(12345)
        eps_distrib = Normal(param.alpha_eps + param.beta_eps*log(z_here), sqrt(param.variance_eps_z))
        eps_draws = rand(eps_distrib, n_eps)

        # density of each draw, rescaled so the weights sum to 1
        eps_density = pdf.(eps_distrib, eps_draws)
        weights = eps_density ./ sum(eps_density)

        # per-draw optimum at this productivity level
        sol = compute_subsidy_baseline_DRS(
            z_star_grid = ones(n_eps) .* z_here, 
            epsilon_grid = eps_draws, 
            param = param)

        # weighted means over the epsilon draws
        profits_C[row] = sum(sol.optimal_profits_C .* weights)
        profits_NC[row] = sum(sol.optimal_profits_NC .* weights)
        m_R_C[row] = sum(sol.optimal_m_R .* weights)
        subsidies_C[row] = sum(sol.optimal_revenue .* (sol.optimal_subsidy ./ (1.0 .+ sol.optimal_subsidy)) .* weights)
        revenue_output_C[row] = sum(sol.revenue_output_C .* weights)
        revenue_output_NC[row] = sum(sol.revenue_output_NC .* weights)
        subsidy_rate_C[row] = sum(sol.optimal_subsidy .* weights)
        revenue_C[row] = sum(sol.optimal_revenue .* weights)

        # keep the full per-draw variation for this grid point
        slots = ((row-1)*n_eps + 1):(row*n_eps)
        all_subsidies[slots] = sol.optimal_subsidy
        all_epsilon[slots] = eps_draws
        all_weights[slots] = weights
        all_rent_shares[slots] = sol.optimal_m_R ./ (sol.optimal_m_R .+ (gamma_gross .* sol.optimal_revenue))
    end 

    # mix the _C and _NC outcomes with weight proba_c on _C
    share_C = param.proba_c
    profits_mix = share_C .* profits_C .+ (1.0 - share_C) .* profits_NC
    subsidies_mix = share_C .* subsidies_C
    revenue_output_mix = share_C .* revenue_output_C .+ (1.0 - share_C) .* revenue_output_NC
    revenue_mix = share_C .* revenue_C .+ (1.0 - share_C) .* revenue_NC

    return (
        expected_profits_C = profits_C,
        expected_profits_NC = profits_NC,
        expected_profits = profits_mix, 
        expected_m_R_C = m_R_C,
        expected_subsidies_C = subsidies_C,
        expected_subsidies = subsidies_mix,  
        expected_revenue_output_C = revenue_output_C,
        expected_revenue_output_NC = revenue_output_NC,
        expected_revenue_output = revenue_output_mix, 
        expected_subsidy_rate_C = subsidy_rate_C,
        expected_revenue_C = revenue_C,
        expected_revenue_NC = revenue_NC, 
        expected_revenue = revenue_mix,
        subsidy_variation = all_subsidies, 
        epsilon_variation = all_epsilon, 
        epsilon_proba_variation = all_weights,
        rent_seeking_share_variation = all_rent_shares
        )
end

# compute GE aggregates without rent-seeking, subsidy, entry-cost or fixed-cost terms
"""
    compute_aggregates_noC_noEE(; w, distrib, mass_firms, within_period_choices,
                                tax = nothing)

Aggregate firm-level expectations into economy-wide totals without rent-seeking, subsidy,
entry-cost or fixed-cost terms.

# Keywords
- `w`: wage.
- `distrib`: weights multiplied element-wise with the fields of `within_period_choices`.
- `mass_firms`: scalar multiplying every weighted sum.
- `within_period_choices`: must provide `expected_revenue`, `expected_revenue_output`
  and `expected_profits`.
- `tax`: VAT rate; the global `vat_tax` is used when `nothing`.

# Returns
A `NamedTuple` with `productive_capital`, `productive_labor`, `productive_intermediates`,
`total_output`, `total_labor_demand` (equal to `productive_labor`), `tax_revenues`,
`total_net_profits`, `total_HH_income`, `total_HH_income_noprofits` and
`total_HH_consumption_closed`.
"""
function compute_aggregates_noC_noEE(; w, distrib, mass_firms, within_period_choices, tax = nothing) 

    # fall back to the global VAT rate; identity comparison is the reliable test for `nothing`
    if tax === nothing
        tax = vat_tax
    end 

    # Get total (productive) inputs 
    productive_capital = sum( within_period_choices.expected_revenue .* alpha_gross .* (1-tax) .* distrib ./ (rental_rate - 1) ) * mass_firms
    productive_labor = sum( within_period_choices.expected_revenue .* beta_gross .* (1-tax) .* distrib ./ w ) * mass_firms
    productive_intermediates = sum( within_period_choices.expected_revenue .* gamma_gross .* distrib ) * mass_firms

    # Get total output
    total_output = sum( within_period_choices.expected_revenue_output .* distrib ) * mass_firms
    total_net_profits = sum( within_period_choices.expected_profits .* distrib ) * mass_firms

    # Can back out total labor supply then 
    total_labor_demand = productive_labor

    # total tax revenues
    # expected_profits is multiplied by profit_tax/(1-profit_tax) to obtain the profit-tax bill
    cit_revenue = sum( within_period_choices.expected_profits .* distrib) * (profit_tax/(1-profit_tax)) * mass_firms
    vat_revenue = sum( tax .* (1-gamma_gross) .* within_period_choices.expected_revenue .* distrib) * mass_firms
    tax_revenues = cit_revenue + vat_revenue

    # compute total HH income (not including capital)
    total_HH_income = tax_revenues + total_net_profits + w*total_labor_demand
    total_HH_income_noprofits = tax_revenues + w*total_labor_demand

    # closed-economy HH consumption 
    total_HH_consumption_closed = total_output - delta*productive_capital - productive_intermediates

    ### return objects ### 
    return (
        productive_capital = productive_capital, productive_labor = productive_labor, productive_intermediates = productive_intermediates, 
        total_output = total_output, total_labor_demand = total_labor_demand,
        tax_revenues = tax_revenues, total_net_profits = total_net_profits, 
        total_HH_income = total_HH_income, total_HH_income_noprofits = total_HH_income_noprofits, 
        total_HH_consumption_closed = total_HH_consumption_closed)
end 

